#include <types.h>
#include <stdio.h>
#include <trap.h>
#include <picirq.h>
#include <fs.h>
#include <ide.h>
#include <arch.h>
#include <sem.h>
#include <assert.h>
#include <kio.h>
#include <ramdisk.h>

#define ISA_DATA                0x00
#define ISA_ERROR               0x01
#define ISA_PRECOMP             0x01
#define ISA_CTRL                0x02
#define ISA_SECCNT              0x02
#define ISA_SECTOR              0x03
#define ISA_CYL_LO              0x04
#define ISA_CYL_HI              0x05
#define ISA_SDH                 0x06
#define ISA_COMMAND             0x07
#define ISA_STATUS              0x07

#define IDE_BSY                 0x80
#define IDE_DRDY                0x40
#define IDE_DF                  0x20
#define IDE_DRQ                 0x08
#define IDE_ERR                 0x01

#define IDE_CMD_READ            0x20
#define IDE_CMD_WRITE           0x30
#define IDE_CMD_IDENTIFY        0xEC

#define IDE_IDENT_SECTORS       20
#define IDE_IDENT_MODEL         54
#define IDE_IDENT_CAPABILITIES  98
#define IDE_IDENT_CMDSETS       164
#define IDE_IDENT_MAX_LBA       120
#define IDE_IDENT_MAX_LBA_EXT   200

#define IO_BASE0                0x1F0
#define IO_BASE1                0x170
#define IO_CTRL0                0x3F4
#define IO_CTRL1                0x374

#define MAX_IDE                 4
#define MAX_NSECS               128
#define MAX_DISK_NSECS          0x10000000U
#define VALID_IDE(ideno)        (((ideno) >= 0) && ((ideno) < MAX_IDE) && (ide_devices[ideno].valid))

static struct {
	const unsigned short base;	// I/O Base
	const unsigned short ctrl;	// Control Base
	semaphore_t sem;
} channels[2] = {
	{
	IO_BASE0, IO_CTRL0}, {
IO_BASE1, IO_CTRL1},};

static void lock_channel(unsigned short ideno)
{
	down(&(channels[ideno >> 1].sem));
}

static void unlock_channel(unsigned short ideno)
{
	up(&(channels[ideno >> 1].sem));
}

#define IO_BASE(ideno)          (channels[(ideno) >> 1].base)
#define IO_CTRL(ideno)          (channels[(ideno) >> 1].ctrl)

static struct ide_device ide_devices[MAX_IDE];

static int ide_wait_ready(unsigned short iobase, bool check_error)
{
	int r;
	while ((r = inb(iobase + ISA_STATUS)) & IDE_BSY)
		/* nothing */ ;
	if (check_error && (r & (IDE_DF | IDE_ERR)) != 0) {
		return -1;
	}
	return 0;
}

void ide_init(void)
{
	static_assert((SECTSIZE % 4) == 0);
	if(initrd_begin){
		struct ide_device *dev = &ide_devices[DISK0_DEV_NO];
		ramdisk_init_struct(dev);
		ramdisk_init(dev);
		return;
	}
	unsigned short ideno, iobase;
	for (ideno = 0; ideno < MAX_IDE; ideno++) {
		/* assume that no device here */
		ide_devices[ideno].valid = 0;

		iobase = IO_BASE(ideno);

		/* wait device ready */
		ide_wait_ready(iobase, 0);

		/* step1: select drive */
		outb(iobase + ISA_SDH, 0xE0 | ((ideno & 1) << 4));
		ide_wait_ready(iobase, 0);

		/* step2: send ATA identify command */
		outb(iobase + ISA_COMMAND, IDE_CMD_IDENTIFY);
		ide_wait_ready(iobase, 0);

		/* step3: polling */
		if (inb(iobase + ISA_STATUS) == 0
		    || ide_wait_ready(iobase, 1) != 0) {
			continue;
		}

		/* device is ok */
		ide_devices[ideno].valid = 1;

		/* read identification space of the device */
		unsigned int buffer[128];
		insl(iobase + ISA_DATA, buffer,
		     sizeof(buffer) / sizeof(unsigned int));

		unsigned char *ident = (unsigned char *)buffer;
		unsigned int sectors;
		unsigned int cmdsets =
		    *(unsigned int *)(ident + IDE_IDENT_CMDSETS);
		/* device use 48-bits or 28-bits addressing */
		if (cmdsets & (1 << 26)) {
			sectors =
			    *(unsigned int *)(ident + IDE_IDENT_MAX_LBA_EXT);
		} else {
			sectors = *(unsigned int *)(ident + IDE_IDENT_MAX_LBA);
		}
		ide_devices[ideno].sets = cmdsets;
		ide_devices[ideno].size = sectors;

		/* check if supports LBA */
		assert((*(unsigned short *)(ident + IDE_IDENT_CAPABILITIES) &
			0x200) != 0);

		unsigned char *model = ide_devices[ideno].model, *data =
		    ident + IDE_IDENT_MODEL;
		unsigned int i, length = 40;
		for (i = 0; i < length; i += 2) {
			model[i] = data[i + 1], model[i + 1] = data[i];
		}
		do {
			model[i] = '\0';
		} while (i-- > 0 && model[i] == ' ');

		kprintf("ide %d: %10u(sectors), '%s'.\n", ideno,
			ide_devices[ideno].size, ide_devices[ideno].model);
	}

	// enable ide interrupt
	pic_enable(IRQ_IDE1);
	pic_enable(IRQ_IDE2);

	sem_init(&(channels[0].sem), 1);
	sem_init(&(channels[1].sem), 1);
}

bool ide_device_valid(unsigned short ideno)
{
	return VALID_IDE(ideno);
}

size_t ide_device_size(unsigned short ideno)
{
	if (ide_device_valid(ideno)) {
		return ide_devices[ideno].size;
	}
	return 0;
}

int ide_read_secs(unsigned short ideno, uint32_t secno, void *dst, size_t nsecs)
{
	assert(nsecs <= MAX_NSECS && VALID_IDE(ideno));
	assert(secno < MAX_DISK_NSECS && secno + nsecs <= MAX_DISK_NSECS);
	if(ide_devices[ideno].ramdisk)
		return ramdisk_read(&ide_devices[ideno], secno, dst, nsecs);
	unsigned short iobase = IO_BASE(ideno), ioctrl = IO_CTRL(ideno);

	lock_channel(ideno);

	ide_wait_ready(iobase, 0);

	// generate interrupt
	outb(ioctrl + ISA_CTRL, 0);
	outb(iobase + ISA_SECCNT, nsecs);
	outb(iobase + ISA_SECTOR, secno & 0xFF);
	outb(iobase + ISA_CYL_LO, (secno >> 8) & 0xFF);
	outb(iobase + ISA_CYL_HI, (secno >> 16) & 0xFF);
	outb(iobase + ISA_SDH,
	     0xE0 | ((ideno & 1) << 4) | ((secno >> 24) & 0xF));
	outb(iobase + ISA_COMMAND, IDE_CMD_READ);

	int ret = 0;
	for (; nsecs > 0; nsecs--, dst += SECTSIZE) {
		if ((ret = ide_wait_ready(iobase, 1)) != 0) {
			goto out;
		}
		insl(iobase, dst, SECTSIZE / sizeof(uint32_t));
	}

out:
	unlock_channel(ideno);
	return ret;
}

int
ide_write_secs(unsigned short ideno, uint32_t secno, const void *src,
	       size_t nsecs)
{
	assert(nsecs <= MAX_NSECS && VALID_IDE(ideno));
	assert(secno < MAX_DISK_NSECS && secno + nsecs <= MAX_DISK_NSECS);

	if(ide_devices[ideno].ramdisk)
		return ramdisk_write(&ide_devices[ideno], secno, src, nsecs);
	unsigned short iobase = IO_BASE(ideno), ioctrl = IO_CTRL(ideno);

	lock_channel(ideno);

	ide_wait_ready(iobase, 0);

	// generate interrupt
	outb(ioctrl + ISA_CTRL, 0);
	outb(iobase + ISA_SECCNT, nsecs);
	outb(iobase + ISA_SECTOR, secno & 0xFF);
	outb(iobase + ISA_CYL_LO, (secno >> 8) & 0xFF);
	outb(iobase + ISA_CYL_HI, (secno >> 16) & 0xFF);
	outb(iobase + ISA_SDH,
	     0xE0 | ((ideno & 1) << 4) | ((secno >> 24) & 0xF));
	outb(iobase + ISA_COMMAND, IDE_CMD_WRITE);

	int ret = 0;
	for (; nsecs > 0; nsecs--, src += SECTSIZE) {
		if ((ret = ide_wait_ready(iobase, 1)) != 0) {
			goto out;
		}
		outsl(iobase, src, SECTSIZE / sizeof(uint32_t));
	}

out:
	unlock_channel(ideno);
	return ret;
}
